More PyMC

Lecture 20

Dr. Colin Rundel
def offset(l, shift=False):
  """Repeat each element of l twice, optionally shifting right by one.

  Used to turn a chain's draws into stair-step trajectory coordinates:
  pairing an unshifted x with a shifted y makes each move render as a
  horizontal segment followed by a vertical one.
  """
  doubled = []
  for item in l:
    doubled.append(item)
    doubled.append(item)
  if shift:
    # Prepend the first element and drop the last, shifting everything
    # one position to the right (length is unchanged).
    doubled = [doubled[0]] + doubled[:-1]
  return doubled

def mh_trajectories(trace, colors = ["magenta","green","yellow","blue"], use_offset=True, chains = None):
  """Plot the (x1, x2) trajectory of each chain in `trace`.

  Parameters
  ----------
  trace      : InferenceData whose posterior contains "x1" and "x2"
  colors     : per-chain colors; cycled if there are more chains than colors
  use_offset : if True, draw stair-step trajectories via offset()
  chains     : iterable of chain indices to plot (default: all chains)
  """
  n = trace.posterior.sizes.get("chain",1)
  if chains is None:
    chains = range(n)

  for i in chains:
    x = trace.posterior["x1"].sel(chain=i).values
    y = trace.posterior["x2"].sel(chain=i).values
    if use_offset:
      # Stagger x and y so each accepted move renders as a horizontal
      # step followed by a vertical one instead of a diagonal.
      x = offset( x )
      y = offset( y, True)

    plt.plot(
      x, y,
      # Cycle through the palette rather than raising IndexError when
      # there are more chains than colors.
      "-o", c=colors[i % len(colors)], #linewidth=0.5, markersize=0.75,
      label=f"Chain {i}", alpha=0.5
    )

  plt.legend()

Samplers - Metropolis-Hastings

Algorithm

For a parameter of interest start with an initial value \(\theta_0\) then for iteration \(t+1\),

  1. Generate a proposal value \(\theta'\) from a proposal distribution \(q(\theta'|\theta_t)\).

  2. Calculate the acceptance probability, \[ \alpha = \text{min}\left(1, \frac{P(\theta'|x)}{P(\theta_t|x)} \frac{q(\theta_t|\theta')}{q(\theta'|\theta_t)}\right) \]

    where \(P(\theta|x)\) is the target posterior distribution.

  3. Accept proposal \(\theta'\) with probability \(\alpha\), if accepted \(\theta_{t+1} = \theta'\) else \(\theta_{t+1} = \theta_t\).

Some considerations:

  • Choice of the proposal distribution matters a lot

  • Results are for the limit as \(t \to \infty\)

  • Concerns are around computational efficiency

Banana Distribution

# Data
# Simulate n observations whose mean depends on (x1, x2) through the
# nonlinear term x1 + x2**2 -- this induces a banana-shaped posterior.
n = 100
x1_mu = .75
x2_mu = .75
y = pm.draw(pm.Normal.dist(mu=x1_mu+x2_mu**2, sigma=1, shape=n))

# Model
with pm.Model() as banana:
  x1 = pm.Normal("x1", mu=0, sigma=1)
  x2 = pm.Normal("x2", mu=0, sigma=1)

  # Rebinding y replaces the simulated array with the observed RV; the
  # original data values are still attached via observed=y.
  y = pm.Normal("y", mu=x1+x2**2, sigma=1, observed=y)

  # Reference fit with PyMC's default (NUTS) sampler.
  trace = pm.sample(draws=50000, chains=1, random_seed=1234)

Visualizations

Metropolis-Hastings Sampler

with banana:
  # Metropolis-Hastings with *no* tuning (tune=0): proposal scales stay
  # at their defaults, so many proposals are rejected and the chains
  # repeat the same value for long stretches (visible below).
  mh = pm.sample(
    draws=100, tune=0,
    step=pm.Metropolis([x1,x2]),
    chains=3, random_seed=1234
  )

Chains

mh.posterior["x1"].sel(chain=0).values
array([ 0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     , -0.10305, -0.00038, -0.00038, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686,
       -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686, -0.02686,  0.05846,  0.05846,  0.05846,  0.05846,  0.05846,  0.05846, -0.29143, -0.29143, -0.29143, -0.29143,
       -0.29143, -0.29143, -0.29143, -0.29143,  0.13235,  0.13235,  0.13235,  0.13235,  0.13235,  0.13235,  0.13235,  0.13235, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775,
       -0.05775, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775, -0.05775,  0.18542,  0.18542, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349,
       -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349, -0.10349,
       -0.10349, -0.10349, -0.10349, -0.10349, -0.10349])
mh.posterior["x2"].sel(chain=0).values
array([0.11441, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.10949, 1.17123, 1.17123, 1.17123, 1.17123,
       1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123,
       1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.17123, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742,
       1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.15742, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326,
       1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.24326, 1.27627])

Trajectories

Metropolis-Hastings Sampler with Tuning

with banana:
  # Same MH sampler but with 1000 tuning iterations first -- the proposal
  # scale is adapted and the chain starts from a higher-density region.
  mhwt = pm.sample(
    draws=100, tune=1000,
    step=pm.Metropolis([x1,x2]),
    chains=3, random_seed=1234
  )

Chains

mhwt.posterior["x1"].sel(chain=0).values
array([-0.2686 , -0.23502, -0.23502, -0.23502, -0.23502, -0.23502, -0.23502, -0.23502, -0.20543, -0.20543, -0.20543, -0.20543, -0.20543, -0.20543, -0.09611, -0.44032, -0.44032, -0.44032, -0.44032,
       -0.46752, -0.46752, -0.46752, -0.46752, -0.46752, -0.46752, -0.46752, -0.46752, -0.46752, -0.46752, -0.46752, -0.4801 , -0.4801 , -0.33297, -0.33297, -0.33297, -0.39293, -0.39293, -0.39293,
       -0.39293, -0.39293, -0.39293, -0.39293, -0.39293, -0.39293, -0.39293, -0.33753, -0.33753, -0.33753, -0.20055, -0.20055, -0.20055, -0.20055, -0.20055, -0.26092, -0.26092, -0.26092, -0.26092,
       -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.26092, -0.2001 , -0.3036 ,  0.25103, -0.08487, -0.08487,
       -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.08487, -0.01692, -0.01692,
       -0.01692, -0.01692, -0.01692, -0.01692, -0.01692])
mhwt.posterior["x2"].sel(chain=0).values
array([1.31378, 1.31378, 1.31378, 1.31378, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.25765, 1.34253, 1.34253, 1.34253, 1.34253,
       1.36654, 1.38837, 1.38837, 1.38837, 1.36808, 1.36808, 1.36808, 1.36808, 1.36808, 1.36808, 1.36808, 1.36808, 1.36808, 1.36808, 1.35255, 1.36456, 1.36456, 1.36456, 1.36456, 1.27858, 1.31813,
       1.29092, 1.29092, 1.29092, 1.32343, 1.29965, 1.29965, 1.29965, 1.29965, 1.29965, 1.29965, 1.29965, 1.29965, 1.28582, 1.28582, 1.28582, 1.28582, 1.28582, 1.28582, 1.28582, 1.3012 , 1.3012 ,
       1.3012 , 1.3012 , 1.3012 , 1.3012 , 1.3012 , 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.21153, 1.18757, 1.18757, 1.18757,
       1.18757, 1.18757, 1.18757, 1.18664, 1.18664, 1.18664, 1.18664, 1.18664, 1.18664, 1.18664, 1.18664, 1.18664, 1.22865, 1.22865, 1.13346, 1.13346])

Trajectories

Effects of tuning / burn-in

There are two confounded effects from letting the sampler tune / burn-in:

  1. We have let the sampler run for 1000 iterations - this gives it a chance to find the areas of higher density and settle in.

This also makes each chain less sensitive to its initial starting position.

  2. We have also tuned the size of our MH proposals to achieve a better acceptance rate - this lets the chains better explore our target distribution.

More samples?

with banana:
  # Tuned MH with 10x more draws (1000) to better cover the banana.
  mh_more = pm.sample(
    draws=1000, tune=1000,
    step=pm.Metropolis([x1,x2]),
    chains=1, random_seed=1234
  )

Even more samples?

with banana:
  # Tuned MH with 10000 draws.
  mh_more = pm.sample(
    draws=10000, tune=1000,
    step=pm.Metropolis([x1,x2]),
    chains=1, random_seed=1234
  )

# Thin the chain: keep every 10th draw, reducing autocorrelation between
# retained draws at the cost of sample size.
mh_more_thin = mh_more.sel(draw=slice(0,None,10))

Bivariate Normal Distribution

# Data
# n iid draws from a standard bivariate normal (mean 0, identity cov).
n = 100
y = pm.draw(pm.MvNormal.dist(mu=np.zeros(2), cov=np.eye(2,2), shape=(n,2)))

# Model
with pm.Model() as biv_normal:
  x1 = pm.Normal("x1", mu=0, sigma=1)
  x2 = pm.Normal("x2", mu=0, sigma=1)

  # Covariance fixed at the identity: only the mean vector is estimated.
  y = pm.MvNormal("y", mu=[x1,x2], cov=np.eye(2,2), observed=y)

  bvn_trace = pm.sample(draws=10000, chains=1, random_seed=1234)

Visualizations

BVM w/ MH

with biv_normal:
  # Tuned MH on the (well-behaved) bivariate-normal posterior for
  # comparison with the banana example.
  mh_bvn = pm.sample(
    draws=1000, tune=1000,
    step=pm.Metropolis([x1,x2]),
    chains=1, random_seed=1234
  )

Sampler - Hamiltonian Methods

Background

Takes advantage of techniques developed in classical mechanics by imagining our parameters of interest as particles with a position and momentum,

\[ H(\theta, \rho) = -\underset{\text{potential}}{\log p(\theta)} - \underset{\text{kinetic}}{\log p(\rho|\theta)} \]

Hamilton’s equations of motion give a set of partial differential equations governing the motion of the “particles” in the system.

A numerical integration method known as Leapfrog is then used to evolve the system some number of discrete steps forward in time.

Due to the limited numerical precision of the leapfrog integrator, a Metropolis acceptance step is typically used, \[ \alpha = \min \left(1, \exp\left( H(\theta, \rho) - H(\theta',\rho') \right) \right) \]

Algorithm parameters

There are a couple of important tuning parameters that are used by Hamiltonian monte carlo methods:

  • \(\epsilon\) is the size of the discrete time steps

  • \(M\) is the mass matrix (or metric) that is used to determine the kinetic energy from the momentum (\(\rho\))

  • \(L\) is the number of leapfrog steps to take per iteration

Generally most of these will be tuned automatically for you by your sampler of choice.

HamiltonianMC

with banana:
  # Hamiltonian Monte Carlo with a fixed number of leapfrog steps
  # (step size and mass matrix tuned during the 1000 tuning draws).
  hmc = pm.sample(
    draws=1000, tune=1000,
    step=pm.HamiltonianMC([x1,x2]),
    chains=2, random_seed=1234
  )

mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
x1 0.708 0.564 -0.371 1.472 0.054 0.032 121.0 31.0 1.00
x2 -0.077 0.817 -1.251 1.314 0.084 0.033 70.0 45.0 1.03

No-U-turn sampler (NUTS)

This is a variation of Hamiltonian monte carlo that automatically tunes the number of leapfrog steps to allow more effective exploration of the parameter space.

Specifically, it uses a tree based algorithm that tracks trajectories forwards and backwards in time. The tree expands until a maximum depth is achieved or a “U-turn” is detected.

NUTS also does not use a metropolis step to select the final parameter value, instead the sample is chosen among the valid candidates along the trajectory.

NUTS

with banana:
  # NUTS: HMC variant that chooses the trajectory length automatically.
  nuts = pm.sample(
    draws=1000, tune=1000,
    step=pm.NUTS([x1,x2]),
    chains=2, random_seed=1234
  )

mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
x1 0.781 0.494 -0.195 1.450 0.063 0.035 73.0 61.0 1.02
x2 0.121 0.767 -1.067 1.291 0.102 0.027 50.0 57.0 1.06

Some considerations

  • Hamiltonian MC methods are all very sensitive to the choice of their tuning parameters (NUTS less so, but adds additional parameters)

  • Hamiltonian MC methods require the gradient of the log density of the parameter of interest for the leapfrog integrator - limits this method to continuous parameters

  • HMC updates are generally more expensive computationally than MH updates, but they also tend to produce chains with lower autocorrelation. Best to think about performance in terms of effective samples per unit of time.

Divergent transitions

Using Stan or PyMC with NUTS you will often see messages / warnings about divergent transitions or divergences.

This is based on the assumption of conservation of energy with regard to the Hamiltonian system - this tells us that \(H(\theta, \rho)\) should remain constant for the “particle” along its trajectory. When \(H(\theta, \rho)\) of the trajectory diverges from its initial value then a divergence is considered to have occurred and positions after that point cannot be considered as the next draw.

The proximate cause of this is a breakdown of the first order approximations in the leapfrog algorithm.

The ultimate cause is usually a highly curved posterior or a posterior where the rate of curvature is changing rapidly.

Solutions?

Very much depend on the nature of the problem - typically we can potentially reparameterize the model and/or adjust some of the tuning parameters to help the sampler deal with the problematic posterior.

For the latter the following options can be passed to pm.sample() or pm.NUTS():

  • target_accept - step size is adjusted to achieve the desired acceptance rate (larger values result in smaller steps which often work better for problematic posteriors)

  • max_treedepth - maximum depth of the trajectory tree

  • step_scale - the initial guess for the step size (scaled down based on the dimensionality of the parameter space)

NUTS (adjusted)

with banana:
  # Raising target_accept from the default makes NUTS take smaller steps,
  # which usually helps on curved posteriors like the banana.
  nuts2 = pm.sample(
    draws=1000, tune=1000,
    step=pm.NUTS([x1,x2], target_accept=0.9),
    chains=2, random_seed=1234
  )

mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
x1 0.578 0.694 -0.740 1.510 0.054 0.041 235.0 180.0 1.01
x2 0.030 0.897 -1.323 1.548 0.073 0.027 125.0 87.0 1.03

Example 1 - Poisson Regression

Data

aids
year cases
0 1981 12
1 1982 14
2 1983 33
3 1984 50
4 1985 67
5 1986 74
6 1987 123
7 1988 141
8 1989 165
9 1990 204
10 1991 253
11 1992 246
12 1993 240

Model

# Build the design matrix (Intercept + year) and response with patsy.
y, X = patsy.dmatrices("cases ~ year", aids)

X_lab = X.design_info.column_names
y = np.asarray(y).flatten()
X = np.asarray(X)

with pm.Model(coords = {"coeffs": X_lab}) as model:
    # Heavy-tailed prior on the regression coefficients.
    b = pm.Cauchy("b", alpha=0, beta=1, dims="coeffs")
    η = X @ b
    # Poisson mean via the log link.
    # NOTE(review): year is uncentered (1981-1993), so intercept and slope
    # are strongly anti-correlated and exp(η) can blow up -- presumably
    # what the pathological summary below is demonstrating.
    λ = pm.Deterministic("λ", np.exp(η))
    
    likelihood = pm.Poisson("y", mu=λ, observed=y)
    
    post = pm.sample(random_seed=1234)

Summary

az.summary(post)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
b[Intercept] -9.810600e+01 1.696410e+02 -404.199 1.750000e-01 8.438100e+01 48.793 4.0 45.0 4.00
b[year] 1.220000e-01 7.600000e-02 0.002 2.060000e-01 3.800000e-02 0.018 5.0 44.0 3.74
λ[0] 2.644633e+146 4.581211e+146 23.261 1.057853e+147 2.281425e+146 NaN 4.0 46.0 3.75
λ[1] 3.137637e+146 5.435226e+146 28.930 1.255055e+147 2.706721e+146 NaN 4.0 47.0 3.75
λ[2] 3.722546e+146 6.448444e+146 35.980 1.489018e+147 3.211299e+146 NaN 4.0 43.0 3.75
λ[3] 4.416491e+146 7.650543e+146 44.749 1.766596e+147 3.809940e+146 NaN 4.0 39.0 3.75
λ[4] 5.239800e+146 9.076734e+146 55.654 2.095920e+147 4.520177e+146 NaN 4.0 36.0 3.74
λ[5] 6.216587e+146 1.076879e+147 69.218 2.486635e+147 5.362814e+146 NaN 4.0 33.0 3.74
λ[6] 7.375464e+146 1.277628e+147 86.015 2.950186e+147 6.362534e+146 NaN 4.0 32.0 3.74
λ[7] 8.750375e+146 1.515799e+147 106.700 3.500150e+147 7.548618e+146 NaN 4.0 32.0 3.74
λ[8] 1.038159e+147 1.798369e+147 131.889 4.152637e+147 8.955808e+146 NaN 4.0 71.0 3.29
λ[9] 1.231690e+147 2.133616e+147 135.458 4.926759e+147 1.062532e+147 NaN 4.0 57.0 3.74
λ[10] 1.461298e+147 2.531358e+147 135.780 5.845190e+147 1.260606e+147 NaN 4.0 57.0 3.74
λ[11] 1.733708e+147 3.003246e+147 136.104 6.934832e+147 1.495604e+147 NaN 4.0 57.0 3.74
λ[12] 2.056901e+147 3.563102e+147 136.428 8.227603e+147 1.774410e+147 NaN 4.0 57.0 3.74

Sampler stats

print(post.sample_stats)
<xarray.Dataset> Size: 496kB
Dimensions:                (chain: 4, draw: 1000)
Coordinates:
  * chain                  (chain) int64 32B 0 1 2 3
  * draw                   (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables: (12/17)
    acceptance_rate        (chain, draw) float64 32kB 1.0 1.0 ... 0.9937 0.9483
    diverging              (chain, draw) bool 4kB False False ... False False
    energy                 (chain, draw) float64 32kB 4.669e+148 ... 96.27
    energy_error           (chain, draw) float64 32kB 0.0 0.0 ... 0.02436
    index_in_trajectory    (chain, draw) int64 32kB -1 -1 1 -3 ... 2 4 1005 254
    largest_eigval         (chain, draw) float64 32kB nan nan nan ... nan nan
    ...                     ...
    process_time_diff      (chain, draw) float64 32kB 3e-05 3.3e-05 ... 0.01353
    reached_max_treedepth  (chain, draw) bool 4kB False False ... True True
    smallest_eigval        (chain, draw) float64 32kB nan nan nan ... nan nan
    step_size              (chain, draw) float64 32kB 8.018e-77 ... 0.001951
    step_size_bar          (chain, draw) float64 32kB 1.257e-93 ... 0.001808
    tree_depth             (chain, draw) int64 32kB 1 1 1 3 1 1 ... 2 4 4 10 10
Attributes:
    created_at:                 2025-04-11T03:14:58.407044+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
    sampling_time:              14.023865938186646
    tuning_steps:               1000

Tree depth

post.sample_stats["tree_depth"].values
array([[ 1,  1,  1,  3,  1,  1,  4,  2,  1,  1,  4,  1,  1,  1,  5,  1,  1,  1,  1,  3,  1,  4,  1,  1,  1,  1,  2,  4,  1,  1, ...,  2,  2,  1,  1,  1,  1,  1,  1,  4,  1,  1,  3,  1,  3,  1,  3,
         1,  4,  3,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  3],
       [ 1,  5,  1,  1,  5,  5,  1,  5,  1,  4,  1,  1,  1,  1,  1,  1,  1,  2,  1,  1,  1,  1,  3,  2,  2,  2,  1,  2,  1,  1, ...,  1,  9,  4,  1,  1,  1,  1,  3,  1,  3,  2,  1,  2,  3,  1,  3,
         4,  1,  1,  2,  2,  2,  1,  2,  1,  4,  1,  1,  2,  1],
       [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, ..., 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10,  2,  5, 10, 10, 10, 10, 10,  2, 10,  9,  2, 10, 10, 10, 10,  2, 10,  2,  2, 10, 10,  9, 10, ..., 10,  4,  8,  2,  3,  9, 10, 10, 10, 10, 10,  3,  3, 10, 10, 10,
         2,  9, 10,  3, 10,  8, 10,  2,  7,  2,  4,  4, 10, 10]], shape=(4, 1000))
post.sample_stats["reached_max_treedepth"].values
array([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,
        False, False, False, ..., False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,
        False, False, False, ..., False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, ...,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [False,  True,  True,  True,  True,  True,  True, False, False,  True,  True, False,  True,  True, False,  True, False, False,  True,  True, False,  True, False, False, False, False,  True,
        False, False,  True, ...,  True, False, False, False, False, False,  True,  True, False, False,  True, False, False,  True,  True,  True, False, False, False, False, False, False, False,
        False, False, False, False, False,  True,  True]], shape=(4, 1000))

Adjusting the sampler

with model:
  # Re-run NUTS allowing much deeper trajectory trees (default is 10),
  # addressing the reached_max_treedepth warnings seen above.
  post = pm.sample(
    random_seed=1234,
    step = pm.NUTS(max_treedepth=20)
  )

Summary

az.summary(post)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
b[Intercept] -397.014 16.482 -430.085 -368.857 0.673 0.522 607.0 714.0 1.01
b[year] 0.202 0.008 0.188 0.219 0.000 0.000 607.0 714.0 1.01
λ[0] 28.354 2.120 24.271 32.130 0.083 0.064 676.0 839.0 1.01
λ[1] 34.685 2.322 30.344 38.973 0.089 0.067 693.0 886.0 1.01
λ[2] 42.433 2.516 37.650 47.018 0.094 0.070 719.0 935.0 1.01
λ[3] 51.915 2.691 46.682 56.655 0.098 0.071 758.0 986.0 1.00
λ[4] 63.520 2.838 57.726 68.380 0.099 0.069 825.0 1132.0 1.00
λ[5] 77.726 2.955 72.235 83.386 0.096 0.063 961.0 1248.0 1.00
λ[6] 95.114 3.056 89.294 100.866 0.086 0.056 1262.0 1574.0 1.00
λ[7] 116.401 3.203 110.884 123.106 0.072 0.051 1977.0 2219.0 1.00
λ[8] 142.461 3.547 135.449 148.765 0.063 0.054 3206.0 2894.0 1.00
λ[9] 174.367 4.340 166.502 182.479 0.077 0.065 3190.0 2869.0 1.00
λ[10] 213.435 5.871 202.877 224.795 0.132 0.088 1965.0 2545.0 1.00
λ[11] 261.273 8.389 245.983 277.402 0.236 0.140 1271.0 2056.0 1.00
λ[12] 319.856 12.146 297.797 343.323 0.387 0.247 989.0 1579.0 1.00

Trace plots

# Trace plots for all model variables (b and the deterministic λ values).
ax = az.plot_trace(post)
plt.show()

Trace plots (again)

# Trace plots for the coefficients only, one panel per coefficient.
ax = az.plot_trace(post.posterior["b"], compact=False)
plt.show()

Predictions (λ)

# Observed case counts with the posterior mean of λ overlaid.
plt.figure(figsize=(12,6))
sns.scatterplot(x="year", y="cases", data=aids)
sns.lineplot(x="year", y=post.posterior["λ"].mean(dim=["chain", "draw"]), data=aids, color='red')
plt.show()

Revised model

# Center year at its minimum so the intercept is interpretable and the
# intercept/slope posterior correlation (and exp overflow) goes away.
# NOTE(review): in a patsy formula `**` is the interaction-power operator,
# and a term crossed with itself collapses to itself -- so "year_min**2"
# adds NO quadratic column (the summary below has only Intercept and
# year_min coefficients). Use I(year_min**2) if a true squared term is
# intended; confirm against the lecture's intent.
y, X = patsy.dmatrices(
  "cases ~ year_min + year_min**2", 
  aids.assign(year_min = lambda x: x.year-np.min(x.year))
)

X_lab = X.design_info.column_names
y = np.asarray(y).flatten()
X = np.asarray(X)

with pm.Model(coords = {"coeffs": X_lab}) as model:
    b = pm.Cauchy("b", alpha=0, beta=1, dims="coeffs")
    η = X @ b
    λ = pm.Deterministic("λ", np.exp(η))
    
    likelihood = pm.Poisson("y", mu=λ, observed=y)
    
    post = pm.sample(random_seed=1234)

Summary

az.summary(post)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
b[Intercept] 3.337 0.071 3.215 3.475 0.002 0.002 981.0 958.0 1.0
b[year_min] 0.203 0.008 0.188 0.218 0.000 0.000 994.0 857.0 1.0
λ[0] 28.219 1.991 24.641 31.999 0.064 0.053 981.0 958.0 1.0
λ[1] 34.539 2.188 30.558 38.632 0.070 0.058 993.0 975.0 1.0
λ[2] 42.278 2.380 38.072 46.824 0.075 0.061 1012.0 1094.0 1.0
λ[3] 51.754 2.560 47.266 56.716 0.079 0.064 1045.0 1118.0 1.0
λ[4] 63.358 2.720 58.559 68.598 0.082 0.064 1100.0 1245.0 1.0
λ[5] 77.568 2.861 72.693 83.302 0.083 0.061 1202.0 1278.0 1.0
λ[6] 94.971 2.998 89.881 101.083 0.080 0.056 1416.0 1627.0 1.0
λ[7] 116.285 3.191 110.086 121.953 0.074 0.049 1879.0 2390.0 1.0
λ[8] 142.391 3.572 135.872 149.126 0.067 0.050 2878.0 2995.0 1.0
λ[9] 174.369 4.359 166.095 182.391 0.070 0.063 3813.0 2975.0 1.0
λ[10] 213.541 5.810 202.480 224.189 0.100 0.087 3386.0 2403.0 1.0
λ[11] 261.528 8.168 246.578 276.887 0.166 0.131 2417.0 2096.0 1.0
λ[12] 320.318 11.684 300.231 343.913 0.270 0.208 1885.0 1923.0 1.0

Trace plots

# Trace plots for the coefficients of the revised model.
ax = az.plot_trace(post.posterior["b"], compact=False)
plt.show()

Predictions (λ)

# Observed case counts with the revised model's posterior mean of λ.
plt.figure(figsize=(12,6))
sns.scatterplot(x="year", y="cases", data=aids)
sns.lineplot(x="year", y=post.posterior["λ"].mean(dim=["chain", "draw"]), data=aids, color='red')
plt.show()

Example 2 - Compound Samplers

Model with a discrete parameter

import pytensor

# n is a shared tensor so the Binomial's trial count can be indexed by
# the discrete indicator i inside the graph.
n = pytensor.shared(np.asarray([10, 20]))
with pm.Model() as m:
    p = pm.Beta("p", 1.0, 1.0)
    # Latent indicator choosing between n=10 and n=20 trials.
    i = pm.Bernoulli("i", 0.5)
    k = pm.Binomial("k", p=p, n=n[i], observed=4)
    
    # Mix samplers: NUTS for the continuous p, a Metropolis step for the
    # binary i (gradient-based samplers cannot handle discrete parameters).
    step = pm.CompoundStep([
      pm.NUTS([p]),
      pm.BinaryMetropolis([i])
    ])

    trace = pm.sample(
      1000, step=step
    )

Summary

az.summary(trace)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
i 0.332 0.471 0.000 1.00 0.015 0.005 969.0 969.0 1.0
p 0.359 0.154 0.095 0.64 0.005 0.002 859.0 1621.0 1.0

Trace plots

# Trace plots for the compound-sampled model.
ax = az.plot_trace(trace)
plt.show()

# Posterior of p split by the sampled value of the indicator i.
d = pd.DataFrame({
  "p": trace.posterior["p"].values.flatten(),
  "i": trace.posterior["i"].values.flatten()
})
sns.displot(d, x="p", hue="i", kind="kde")
plt.show()

d.groupby("i").mean()
p
i
0 0.422838
1 0.230261

If we assume i=0: \[ \begin{aligned} p|x=4,i=0 \sim \text{Beta}(5,7) \\ E(p|x=4,i=0) = \frac{5}{5+7} = 0.416 \end{aligned} \]

If we assume i=1: \[ \begin{aligned} p|x=4,i=1 \sim \text{Beta}(5,17) \\ E(p|x=4,i=1) = \frac{5}{5+17} = 0.227 \end{aligned} \]